# dqn_utils.py
import ast
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

# =============== Common utility functions ===============
def set_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator; when CUDA is available, all CUDA generators are seeded too.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# =============== Data preprocessing functions ===============

def load_dataset(csv_file):
    """Load a transition dataset from CSV into NumPy arrays.

    The CSV must provide the columns ``pid``, ``state``, ``action``,
    ``new_reward`` and ``done``. Each ``state`` cell holds the text of a
    Python literal sequence (for example ``"[1, 0, 1]"``).

    Rows sharing a ``pid`` form one trajectory, in file order. The next state
    of a row is the state of the following row with the same ``pid``; the
    last row of each trajectory uses its own state as its next state.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A tuple ``(states, next_states, actions, rewards, dones, state_dim,
        n_actions)`` where

        * ``states`` / ``next_states`` are ``float32`` arrays with one row per
          CSV row and ``state_dim`` columns,
        * ``actions`` is an ``int64`` array,
        * ``rewards`` is a ``float32`` array read from ``new_reward``,
        * ``dones`` is an integer array read from ``done``,
        * ``state_dim`` is the length of the first parsed state,
        * ``n_actions`` is ``max(action) + 1`` as a plain ``int``.

    Raises:
        ValueError, SyntaxError: If a ``state`` cell is not a plain Python
            literal.
        IndexError: If the CSV contains no data rows.
    """
    df = pd.read_csv(csv_file)

    # NOTE(security): this used to be ``eval``, which executes arbitrary code
    # embedded in the data file. ``ast.literal_eval`` only accepts literals
    # (lists, tuples, numbers, ...), so expression-valued cells now raise.
    parsed_states = df["state"].apply(ast.literal_eval)
    state_dim = len(parsed_states.iloc[0])

    states = np.vstack(parsed_states.tolist()).astype(np.float32)

    # Next state = state of the following row of the same pid (file order).
    pid = df["pid"]
    shifted = (
        pd.DataFrame(states, index=df.index).groupby(pid).shift(-1).to_numpy()
    )
    # The last row of every trajectory has no successor and reuses its own
    # state. Rows with a missing pid belong to no group (groupby drops them),
    # so they are excluded here and keep NaN next states.
    is_last = (pid.notna() & ~pid.duplicated(keep="last")).to_numpy()
    next_states = np.where(is_last[:, None], states, shifted).astype(np.float32)

    actions = df["action"].values.astype(np.int64)

    # Rewards come from the new_reward column.
    rewards = df["new_reward"].values.astype(np.float32)
    dones = df["done"].astype(int).values
    # Plain int so it can be used directly as a layer size.
    n_actions = int(actions.max()) + 1

    return states, next_states, actions, rewards, dones, state_dim, n_actions

# =============== DQN Network ===============
class DQN(nn.Module):
    """Fully connected Q-network: state vector in, one Q-value per action out.

    The network has two hidden layers of 128 units, each followed by a ReLU,
    and a final linear layer of width ``action_dim``.
    """

    def __init__(self, state_dim, action_dim):
        """Build the layer stack.

        Args:
            state_dim: Number of input features.
            action_dim: Number of outputs (one per action).
        """
        super().__init__()
        hidden_units = 128
        layer_stack = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
        ]
        # Stored as ``net`` so parameter names remain ``net.<index>.*``.
        self.net = nn.Sequential(*layer_stack)

    def forward(self, x):
        """Return the network output for input ``x``."""
        q_values = self.net(x)
        return q_values

# =============== DQN Training Function ===============
def train_dqn(seed, states, next_states, actions, rewards, dones, state_dim, n_actions,
              n_epochs=10, batch_size=64, gamma=0.99):
    """Fit a DQN on a fixed set of transitions using one-step TD targets.

    Every epoch shuffles the transitions, splits them into mini-batches and
    minimises the mean squared error between ``Q(s, a)`` and
    ``r + gamma * max_a' Q(s', a') * (1 - done)``. The bootstrap term is
    computed with the same network that is being trained (under
    ``torch.no_grad``); no separate target network is used.

    Args:
        seed: Passed to ``set_seed`` before the network is created.
        states: Array-like of states, one row per transition.
        next_states: Array-like of successor states, aligned with ``states``.
        actions: Array-like of integer action indices.
        rewards: Array-like of scalar rewards.
        dones: Array-like of terminal flags (non-zero disables bootstrapping).
        state_dim: Input width of the network.
        n_actions: Output width of the network.
        n_epochs: Number of passes over the dataset.
        batch_size: Mini-batch size; the final batch may be smaller.
        gamma: Discount factor.

    Returns:
        ``(policy_net, loss_history)`` where ``loss_history`` holds the mean
        batch loss of each epoch.

    Raises:
        ZeroDivisionError: If the dataset is empty and ``n_epochs > 0``
            (there are no batches to average over).
    """
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    policy_net = DQN(state_dim, n_actions).to(device)
    optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    # Convert the dataset to tensors once rather than once per mini-batch.
    # They stay on the CPU; only the selected batch is moved to ``device``.
    all_s = torch.tensor(states, dtype=torch.float32)
    all_a = torch.tensor(actions, dtype=torch.int64)
    all_r = torch.tensor(rewards, dtype=torch.float32)
    all_ns = torch.tensor(next_states, dtype=torch.float32)
    all_d = torch.tensor(dones, dtype=torch.float32)

    dataset_size = len(states)
    loss_history = []

    for _ in range(n_epochs):
        # NumPy's global RNG (seeded by set_seed) drives the shuffle.
        perm = torch.as_tensor(np.random.permutation(dataset_size), dtype=torch.long)
        total_loss, n_batches = 0, 0
        for start in range(0, dataset_size, batch_size):
            batch_idx = perm[start:start + batch_size]
            s = all_s[batch_idx].to(device)
            a = all_a[batch_idx].to(device)
            r = all_r[batch_idx].to(device)
            ns = all_ns[batch_idx].to(device)
            d = all_d[batch_idx].to(device)

            # squeeze(1), not squeeze(): a final batch of size 1 must stay
            # 1-D so that its shape matches ``target``.
            q_values = policy_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
            with torch.no_grad():
                max_next_q = policy_net(ns).max(1)[0]
                # (1 - d) removes the bootstrap term on terminal transitions.
                target = r + gamma * max_next_q * (1 - d)

            loss = loss_fn(q_values, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1
        loss_history.append(total_loss / n_batches)
    return policy_net, loss_history

# =============== Policy Evaluation Function ===============
def evaluate_policy(env, policy_net, n_episodes=200):
    """Roll out the greedy policy of ``policy_net`` in ``env``.

    Each episode starts with ``env.reset()`` and repeatedly takes the action
    with the largest network output until ``env.step`` reports ``done``.
    ``env.step`` is expected to return ``(state, reward, done, info)``.
    An episode is counted as censored when ``info["delta"] == 0`` on the
    step that ends it.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``.
        policy_net: Network mapping a batch of states to per-action values.
        n_episodes: Number of episodes to run.

    Returns:
        ``(mean_return, censored_fraction)``: the mean of the summed episode
        rewards and the fraction of episodes counted as censored.
    """
    device = next(policy_net.parameters()).device
    episode_returns = []
    n_censored = 0

    for _ in range(n_episodes):
        obs = env.reset()
        episode_return = 0
        while True:
            # Add a batch dimension so the network sees a (1, state_dim) input.
            obs_batch = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                greedy_action = policy_net(obs_batch).argmax(1).item()

            obs, reward, finished, info = env.step(greedy_action)
            episode_return += reward

            if finished:
                # ``info`` is only inspected on the terminal step.
                if info["delta"] == 0:
                    n_censored += 1
                break
        episode_returns.append(episode_return)

    return np.mean(episode_returns), n_censored / n_episodes